Delete tree node [DFS, DP]

Time: O(N); Space: O(N); medium

A tree rooted at node 0 is given as follows:

  • The number of nodes is nodes;

  • The value of the i-th node is value[i];

  • The parent of the i-th node is parent[i].

Remove every subtree whose sum of values of nodes is zero.

After doing so, return the number of nodes remaining in the tree.

Example 1:

Input: nodes = 7, parent = [-1,0,0,1,2,2,2], value = [1,-2,4,0,-2,-1,-1]

Output: 2

Constraints:

  • 1 <= nodes <= 10^4

  • -10^5 <= value[i] <= 10^5

  • len(parent) == nodes

  • parent[0] == -1 which indicates that 0 is the root.

1. DFS

[1]:
import collections

class Solution1(object):
    """
    Iterative post-order DFS driven by an explicit stack.

    An explicit stack is used instead of recursion so that deep (chain-like)
    trees do not exceed the interpreter's recursion limit.

    Time: O(N)
    Space: O(N)
    """
    def deleteTreeNodes(self, nodes, parent, value):
        """
        Count the nodes that remain after removing every subtree whose
        values sum to zero. The inputs are not modified.

        :type nodes: int
        :type parent: List[int]
        :type value: List[int]
        :rtype: int
        """
        if not value:
            return 0

        # Node 0 is the root, so it is never registered as anybody's child.
        children = collections.defaultdict(list)
        for i, p in enumerate(parent):
            if i:
                children[p].append(i)

        total = list(value)        # total[x] ends up as the sum of subtree x
        count = [1] * len(total)   # count[x] ends up as the nodes kept in subtree x

        # Each node is pushed twice: first to schedule its children, then
        # (flag set) to be folded once all of its children are finished.
        stack = [(0, False)]
        while stack:
            x, children_done = stack.pop()
            if not children_done:
                stack.append((x, True))
                stack.extend((y, False) for y in children[x])
                continue

            for y in children[x]:
                total[x] += total[y]
                count[x] += count[y]
            if not total[x]:
                # A zero-sum subtree is removed as a whole.
                count[x] = 0

        return count[0]
[2]:
# Smoke test: a 7-node tree in which only two nodes survive the pruning.
nodes = 7
parent = [-1, 0, 0, 1, 2, 2, 2]
value = [1, -2, 4, 0, -2, -1, -1]

s = Solution1()
assert s.deleteTreeNodes(nodes, parent, value) == 2

2. DFS

[3]:
import collections

class Solution2(object):
    """
    Iterative DFS: record a pre-order, then fold it in reverse.

    Walking the pre-order backwards visits every node after all of its
    descendants, which gives the effect of a post-order traversal without
    recursion (and therefore without hitting the recursion limit on deep
    trees).

    Time: O(N)
    Space: O(N)
    """
    def deleteTreeNodes(self, nodes, parent, value):
        """
        Count the nodes that remain after removing every subtree whose
        values sum to zero. The inputs are not modified.

        :type nodes: int
        :type parent: List[int]
        :type value: List[int]
        :rtype: int
        """
        if not value:
            return 0

        tree = collections.defaultdict(list)
        for i, p in enumerate(parent):
            # The root is filed under its own parent entry (-1), which is
            # never looked up during the traversal below.
            tree[p].append(i)

        # Pre-order from the root: a parent always precedes its descendants.
        order, stack = [], [0]
        while stack:
            cur = stack.pop()
            order.append(cur)
            stack.extend(tree[cur])

        sum_all = list(value)        # subtree sums, accumulated bottom-up
        cnt = [1] * len(sum_all)     # surviving node count per subtree

        for cur in reversed(order):
            if not sum_all[cur]:
                # A zero-sum subtree is removed as a whole.
                cnt[cur] = 0
            if cur:
                # Every non-root node contributes to its parent.
                p = parent[cur]
                sum_all[p] += sum_all[cur]
                cnt[p] += cnt[cur]

        return cnt[0]
[4]:
# Smoke test: a 7-node tree in which only two nodes survive the pruning.
nodes = 7
parent = [-1, 0, 0, 1, 2, 2, 2]
value = [1, -2, 4, 0, -2, -1, -1]

s = Solution2()
assert s.deleteTreeNodes(nodes, parent, value) == 2

3. DP

[5]:
class Solution3(object):
    """
    Bottom-up DP that folds each finished node into its parent.

    Nodes are processed leaf-first by tracking how many children of each
    node are still unprocessed, so no particular numbering of the nodes
    (such as parent[i] < i) is required.

    Time: O(N)
    Space: O(N)
    """
    def deleteTreeNodes(self, nodes, parent, value):
        """
        Count the nodes that remain after removing every subtree whose
        values sum to zero. The inputs are not modified.

        :type nodes: int
        :type parent: List[int]
        :type value: List[int]
        :rtype: int
        """
        if nodes <= 0:
            return 0

        total = list(value)    # work on a copy so the caller's list is untouched
        result = [1] * nodes   # result[i]: nodes kept in the subtree rooted at i

        # pending[i]: number of children of i that have not been folded yet.
        pending = [0] * nodes
        for i in range(1, nodes):
            pending[parent[i]] += 1

        # Start from the leaves; a node becomes ready once all of its
        # children have been folded into it.
        ready = [i for i in range(nodes) if not pending[i]]
        while ready:
            i = ready.pop()
            if i == 0:
                # The root has no parent to fold into.
                continue
            p = parent[i]
            total[p] += total[i]
            # A zero-sum subtree contributes no surviving nodes.
            result[p] += result[i] if total[i] else 0
            pending[p] -= 1
            if not pending[p]:
                ready.append(p)

        # If the whole tree sums to zero, the root itself is removed too.
        return result[0] if total[0] else 0
[6]:
# Smoke test: a 7-node tree in which only two nodes survive the pruning.
nodes = 7
parent = [-1, 0, 0, 1, 2, 2, 2]
value = [1, -2, 4, 0, -2, -1, -1]

s = Solution3()
assert s.deleteTreeNodes(nodes, parent, value) == 2

4. DP (two passes)

[7]:
class Solution4(object):
    """
    Two-pass DP: accumulate subtree sums bottom-up, then mark removed
    nodes top-down and subtract them from the node count.

    Both passes follow a breadth-first order computed from the root, so no
    particular numbering of the nodes (such as parent[i] < i) is required.

    Time: O(N)
    Space: O(N)
    """
    def deleteTreeNodes(self, nodes, parent, value):
        """
        Count the nodes that remain after removing every subtree whose
        values sum to zero. The inputs are not modified.

        :type nodes: int
        :type parent: List[int]
        :type value: List[int]
        :rtype: int
        """
        if nodes <= 0:
            return 0

        children = [[] for _ in range(nodes)]
        for i in range(1, nodes):
            children[parent[i]].append(i)

        # Breadth-first order: every node appears after its parent.
        order, head = [0], 0
        while head < len(order):
            order.extend(children[order[head]])
            head += 1

        # Pass 1 (children before parents): subtree sums, on a copy so the
        # caller's list is left untouched.
        total = list(value)
        for i in reversed(order):
            if i:
                total[parent[i]] += total[i]

        # Pass 2 (parents before children): a node is removed if its own
        # subtree sums to zero or if its parent has already been removed.
        # This includes children of the root when the root is removed.
        zeros, isZero = 0, [0] * nodes
        for i in order:
            if (i and isZero[parent[i]]) or total[i] == 0:
                isZero[i] = 1
                zeros += 1

        return nodes - zeros
[8]:
# Smoke test: a 7-node tree in which only two nodes survive the pruning.
nodes = 7
parent = [-1, 0, 0, 1, 2, 2, 2]
value = [1, -2, 4, 0, -2, -1, -1]

s = Solution4()
assert s.deleteTreeNodes(nodes, parent, value) == 2